/**
 * \file dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl
 *
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#ifndef KERN_APPLY_OPR_OPR
#error "must define KERN_APPLY_OPR_OPR"
#endif

#include "src/rocm/utils.h.hip"
#include "./kern.h.hip"
#include "megdnn/internal/defs.h"
#include "megdnn/dtype.h"

using namespace megdnn;
using namespace rocm;
using namespace indexing_multi_axis_vec;

namespace {
    /*!
     * \brief Apply Opr to one (data, value) element pair per thread.
     *
     * The flat thread index is used directly as the index into the value
     * tensor (scaled by param.value_stride). It is also decomposed, from the
     * last axis to the first, into per-axis coordinates that are combined with
     * param.value_ly_on_data.stride to form an offset into param.data. An
     * extra per-index offset read from param.offset_base is then added.
     *
     * Threads whose flat index is not below param.tot_size do nothing.
     *
     * \tparam ctype element type carried by ApplyOprParam
     * \tparam ndim number of axes walked by the decomposition loop
     * \tparam Opr type providing a static apply(data_elem, value_elem)
     * \param param kernel arguments, passed by value
     */
    template<typename ctype, int ndim, class Opr>
    __global__ void kapply_opr(ApplyOprParam<ctype, ndim> param) {
        const uint32_t oidx = threadIdx.x + blockDim.x * blockIdx.x;
        if (oidx >= param.tot_size) {
            return;
        }

        int data_offset = 0;
        // part of the flat index that has not been split into coordinates yet
        int rest = oidx;
        int idx_flat = 0;
#pragma unroll
        for (int axis = ndim - 1; axis >= 0; --axis) {
            // Remember the not-yet-decomposed index when reaching the end of
            // the indexed axis range ...
            if (axis + 1 == param.idx_axis_end) {
                idx_flat = rest;
            }
            // ... and subtract the scaled remainder at its start. axis + 1
            // only takes values in [1, ndim], so this branch is never taken
            // when param.idx_axis is 0.
            if (axis + 1 == param.idx_axis) {
                idx_flat -= rest * param.idx_nelems;
            }

            // For axis 0 whatever is left is the coordinate itself.
            int ax_idx = rest;
            if (axis > 0) {
                // quotient / remainder w.r.t. shape[axis - 1]
                const int quot = rest / param.value_ly_on_data.shape[axis - 1];
                ax_idx =
                    rest -
                    (quot *
                     param.value_ly_on_data.shape[axis - 1].divisor());
                rest = quot;
            }
            data_offset += param.value_ly_on_data.stride[axis] * ax_idx;
        }

        data_offset += param.offset_base[idx_flat];
        Opr::apply(
                param.data[data_offset],
                param.value[oidx * param.value_stride]);
    }
}

/*!
 * \brief Enqueue kapply_opr<ctype, ndim, Opr> on \p stream.
 *
 * The launch uses blocks of 256 threads and enough blocks to provide at least
 * param.tot_size threads; the kernel itself masks out the surplus threads of
 * the last block. Nothing is launched when param.tot_size is zero.
 *
 * \param param kernel arguments, copied into the launch
 * \param stream HIP stream the kernel is enqueued on
 */
template<typename ctype, int ndim, class Opr>
void indexing_multi_axis_vec::apply_opr(
        const ApplyOprParam<ctype, ndim> &param, hipStream_t stream) {
    // DIVUP would yield a grid of zero blocks for an empty problem, which is
    // an invalid launch configuration; there is no work to do in that case.
    if (param.tot_size == 0) {
        return;
    }
    void (*kptr)(ApplyOprParam<ctype, ndim>) = kapply_opr<ctype, ndim, Opr>;
    int bsize = 256;
    hipLaunchKernelGGL(kptr,
                       DIVUP(param.tot_size, bsize), bsize, 0, stream,
                       param);
}

namespace megdnn {
namespace rocm {
namespace indexing_multi_axis_vec {

// Explicit instantiations of apply_opr for the operator named by
// KERN_APPLY_OPR_OPR, which the including file must define (see the #error
// check at the top of this file).

// INST(_ndim, _ctype): one explicit instantiation of
// apply_opr<_ctype, _ndim, KERN_APPLY_OPR_OPR>.
#define INST(_ndim, _ctype) \
    template void apply_opr<_ctype, _ndim, KERN_APPLY_OPR_OPR> \
    (const ApplyOprParam<_ctype, _ndim>&, hipStream_t);
// cb(_dtype): passes INST and the C type of _dtype to
// MEGDNN_FOREACH_TENSOR_NDIM, which is defined outside this file and
// presumably applies INST once per supported tensor ndim -- TODO confirm.
#define cb(_dtype) \
    MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype)
    // Defined outside this file; presumably invokes cb once per computing
    // dtype -- TODO confirm.
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
    // Bool is instantiated explicitly in addition to the list above.
    cb(::megdnn::dtype::Bool)
// Keep the helper macros local to this file.
#undef cb
#undef INST

} // namespace indexing_multi_axis_vec
} // namespace rocm
} // namespace megdnn

// vim: ft=cuda syntax=cpp.doxygen

